{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp models.ResCNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ResCNN\n",
    "\n",
    "> This is an unofficial PyTorch implementation by Ignacio Oguiza - oguiza@gmail.com based on:\n",
    "* Zou, X., Wang, Z., Li, Q., & Sheng, W. (2019). Integration of residual network and convolutional neural network along with various activation functions and global pooling for time series classification. Neurocomputing.\n",
    "* No official implementation found"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from tsai.imports import *\n",
    "from tsai.utils import *\n",
    "from tsai.models.layers import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.models.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class ResCNNBlock(Module):\n",
    "    \"\"\"Residual block: three stacked `ConvBlock`s summed with a projected copy of the input.\n",
    "\n",
    "    The main path applies three `ConvBlock`s with kernel sizes `kss[0]`, `kss[1]`, `kss[2]`;\n",
    "    the third one is built with `act=None`. The input is passed through a kernel-size-1\n",
    "    `ConvBN` shortcut, added to the main path, and the sum goes through a ReLU.\n",
    "\n",
    "    Args:\n",
    "        ni: number of input channels.\n",
    "        nf: number of output channels of every conv in the block (and of the shortcut).\n",
    "        kss: kernel sizes of the three convolutions; only indices 0, 1 and 2 are read.\n",
    "        coord: forwarded as `coord=` to the three `ConvBlock`s and to the shortcut `ConvBN`.\n",
    "        separable: forwarded as `separable=` to the three `ConvBlock`s (not to the shortcut).\n",
    "        zero_norm: forwarded as `zero_norm=` to the last `ConvBlock` only.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, ni, nf, kss=[7, 5, 3], coord=False, separable=False, zero_norm=False):\n",
    "        # options shared by every conv on the main path\n",
    "        conv_kwargs = dict(coord=coord, separable=separable)\n",
    "        self.convblock1 = ConvBlock(ni, nf, kss[0], **conv_kwargs)\n",
    "        self.convblock2 = ConvBlock(nf, nf, kss[1], **conv_kwargs)\n",
    "        # no activation here: the ReLU is applied after the residual sum in forward()\n",
    "        self.convblock3 = ConvBlock(nf, nf, kss[2], act=None, zero_norm=zero_norm, **conv_kwargs)\n",
    "\n",
    "        # projection shortcut: always applied (even when ni == nf) so the residual has nf channels\n",
    "        self.shortcut = ConvBN(ni, nf, 1, coord=coord)\n",
    "        self.add = Add()\n",
    "        self.act = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        # main path first, then the shortcut on the untouched input\n",
    "        out = self.convblock3(self.convblock2(self.convblock1(x)))\n",
    "        out = self.add(out, self.shortcut(x))\n",
    "        return self.act(out)\n",
    "\n",
    "\n",
    "class ResCNN(Module):\n",
    "    \"\"\"ResCNN classifier: one residual block, three conv blocks, global average pooling, linear head.\n",
    "\n",
    "    Layer widths are fixed: the residual block outputs 64 channels, followed by conv blocks\n",
    "    with 128 (LeakyReLU, negative_slope=0.2), 256 (PReLU) and 128 (ELU, alpha=0.3) channels.\n",
    "    The result is average-pooled to length 1, squeezed on the last dim and fed to\n",
    "    `nn.Linear(128, c_out)`.\n",
    "\n",
    "    Args:\n",
    "        c_in: number of input channels.\n",
    "        c_out: number of outputs of the final linear layer.\n",
    "        coord: forwarded as `coord=` to the residual block and to every `ConvBlock`.\n",
    "        separable: forwarded as `separable=` to the residual block and to every `ConvBlock`.\n",
    "        zero_norm: forwarded as `zero_norm=` to the residual block only.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, c_in, c_out, coord=False, separable=False, zero_norm=False):\n",
    "        base = 64  # base width; the following blocks use 2x, 4x and 2x this value\n",
    "        conv_kwargs = dict(coord=coord, separable=separable)\n",
    "        self.block1 = ResCNNBlock(c_in, base, kss=[7, 5, 3], zero_norm=zero_norm, **conv_kwargs)\n",
    "        self.block2 = ConvBlock(base, base * 2, 3, act=nn.LeakyReLU, act_kwargs={'negative_slope':.2}, **conv_kwargs)\n",
    "        self.block3 = ConvBlock(base * 2, base * 4, 3, act=nn.PReLU, **conv_kwargs)\n",
    "        self.block4 = ConvBlock(base * 4, base * 2, 3, act=nn.ELU, act_kwargs={'alpha':.3}, **conv_kwargs)\n",
    "        self.gap = nn.AdaptiveAvgPool1d(1)\n",
    "        self.squeeze = Squeeze(-1)\n",
    "        self.lin = nn.Linear(base * 2, c_out)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # feature extractor: residual block followed by the three conv blocks\n",
    "        feats = self.block4(self.block3(self.block2(self.block1(x))))\n",
    "        # pool the last dim down to size 1, drop it, then classify\n",
    "        pooled = self.squeeze(self.gap(feats))\n",
    "        return self.lin(pooled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: output shape with coord + separable convs, then the first value\n",
    "# returned by total_params for the default model\n",
    "xb = torch.rand(16, 3, 10)\n",
    "model = ResCNN(3, 2, coord=True, separable=True)\n",
    "test_eq(model(xb).shape, [xb.shape[0], 2])\n",
    "n_params = total_params(ResCNN(3, 2))[0]\n",
    "test_eq(n_params, 257283)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ResCNN(\n",
       "  (block1): ResCNNBlock(\n",
       "    (convblock1): ConvBlock(\n",
       "      (0): AddCoords1d()\n",
       "      (1): SeparableConv1d(\n",
       "        (depthwise_conv): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), groups=4, bias=False)\n",
       "        (pointwise_conv): Conv1d(4, 64, kernel_size=(1,), stride=(1,), bias=False)\n",
       "      )\n",
       "      (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (3): ReLU()\n",
       "    )\n",
       "    (convblock2): ConvBlock(\n",
       "      (0): AddCoords1d()\n",
       "      (1): SeparableConv1d(\n",
       "        (depthwise_conv): Conv1d(65, 65, kernel_size=(5,), stride=(1,), padding=(2,), groups=65, bias=False)\n",
       "        (pointwise_conv): Conv1d(65, 64, kernel_size=(1,), stride=(1,), bias=False)\n",
       "      )\n",
       "      (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (3): ReLU()\n",
       "    )\n",
       "    (convblock3): ConvBlock(\n",
       "      (0): AddCoords1d()\n",
       "      (1): SeparableConv1d(\n",
       "        (depthwise_conv): Conv1d(65, 65, kernel_size=(3,), stride=(1,), padding=(1,), groups=65, bias=False)\n",
       "        (pointwise_conv): Conv1d(65, 64, kernel_size=(1,), stride=(1,), bias=False)\n",
       "      )\n",
       "      (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (shortcut): ConvBlock(\n",
       "      (0): AddCoords1d()\n",
       "      (1): Conv1d(4, 64, kernel_size=(1,), stride=(1,), bias=False)\n",
       "      (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (add): Add()\n",
       "    (act): ReLU()\n",
       "  )\n",
       "  (block2): ConvBlock(\n",
       "    (0): AddCoords1d()\n",
       "    (1): SeparableConv1d(\n",
       "      (depthwise_conv): Conv1d(65, 65, kernel_size=(3,), stride=(1,), padding=(1,), groups=65, bias=False)\n",
       "      (pointwise_conv): Conv1d(65, 128, kernel_size=(1,), stride=(1,), bias=False)\n",
       "    )\n",
       "    (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (3): LeakyReLU(negative_slope=0.2)\n",
       "  )\n",
       "  (block3): ConvBlock(\n",
       "    (0): AddCoords1d()\n",
       "    (1): SeparableConv1d(\n",
       "      (depthwise_conv): Conv1d(129, 129, kernel_size=(3,), stride=(1,), padding=(1,), groups=129, bias=False)\n",
       "      (pointwise_conv): Conv1d(129, 256, kernel_size=(1,), stride=(1,), bias=False)\n",
       "    )\n",
       "    (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (3): PReLU(num_parameters=1)\n",
       "  )\n",
       "  (block4): ConvBlock(\n",
       "    (0): AddCoords1d()\n",
       "    (1): SeparableConv1d(\n",
       "      (depthwise_conv): Conv1d(257, 257, kernel_size=(3,), stride=(1,), padding=(1,), groups=257, bias=False)\n",
       "      (pointwise_conv): Conv1d(257, 128, kernel_size=(1,), stride=(1,), bias=False)\n",
       "    )\n",
       "    (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (3): ELU(alpha=0.3)\n",
       "  )\n",
       "  (gap): AdaptiveAvgPool1d(output_size=1)\n",
       "  (squeeze): Squeeze()\n",
       "  (lin): Linear(in_features=128, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ResCNN(3,2,coord=True, separable=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  0 BatchNorm1d     shape: [64]             mean: +1.000  std: +0.000\n",
      "  1 BatchNorm1d     shape: [64]             mean: +1.000  std: +0.000\n",
      "  2 BatchNorm1d     shape: [64]             mean: +0.000  std: +0.000\n",
      "  3 BatchNorm1d     shape: [64]             mean: +1.000  std: +0.000\n",
      "  4 BatchNorm1d     shape: [128]            mean: +1.000  std: +0.000\n",
      "  5 BatchNorm1d     shape: [256]            mean: +1.000  std: +0.000\n",
      "  6 BatchNorm1d     shape: [128]            mean: +1.000  std: +0.000\n"
     ]
    }
   ],
   "source": [
    "check_weight(ResCNN(3,2, zero_norm=True), is_bn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "out = create_scripts()\n",
    "beep(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
